from collections import OrderedDict
from text.symbols import symbols
import torch

from tools.log import logger
import utils
from models import SynthesizerTrn
import os


def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module`` key component removed.

    If the first key starts with ``"module"``, the first dot-separated component
    is dropped from *every* key (e.g. ``"module.enc.weight"`` -> ``"enc.weight"``).
    This is presumably the prefix added by torch's DataParallel/DDP wrappers —
    confirm against the checkpoints this is used with. Otherwise keys are copied
    unchanged. Only the first key is inspected to make that decision.

    Args:
        state_dict: Mapping whose keys are dot-separated names.

    Returns:
        A new ``OrderedDict`` with the (possibly stripped) keys mapped to the
        original values, in the original iteration order. An empty input yields
        an empty ``OrderedDict``.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        return new_state_dict

    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith("module") else 0
    for k, v in state_dict.items():
        # Re-join with the same "." separator used to split, so nested names
        # such as "enc_q.conv.weight" keep their original form.
        new_state_dict[".".join(k.split(".")[start_idx:])] = v
    return new_state_dict


def _filter_model_weights(model_state, ishalf):
    """Drop every ``enc_q`` entry from ``model_state`` and optionally cast to FP16.

    Args:
        model_state: Mapping of parameter/buffer name -> value (the checkpoint's
            ``"model"`` entry).
        ishalf: When true, floating-point tensors are cast with ``.half()``.
            Non-floating tensors (e.g. integer buffers) and non-tensor values
            are kept unchanged so their dtype is not silently altered.

    Returns:
        A new dict with the retained entries, in the original key order.
    """
    kept = {}
    for name, value in model_state.items():
        # Any key containing "enc_q" is excluded from the release checkpoint
        # (presumably a training-only submodule — confirm against the model).
        if "enc_q" in name:
            continue
        if ishalf and torch.is_tensor(value) and value.is_floating_point():
            value = value.half()
        kept[name] = value
    return kept


def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a release copy of a training checkpoint.

    Loads ``input_model`` and drops every ``"model"`` entry whose key contains
    ``"enc_q"``. Floating-point weights are optionally cast to FP16. The result
    is saved to ``output_model``, together with:

    - the state dict of a freshly constructed (never stepped) AdamW optimizer,
    - ``iteration`` reset to 0,
    - a fixed ``learning_rate`` of 0.0001.

    Args:
        config: Path to the hyper-parameter file read by
            ``utils.get_hparams_from_file``.
        input_model: Path of the checkpoint to read (loaded onto the CPU).
        ishalf: Cast floating-point model tensors to ``float16`` when true.
        output_model: Destination path for the new checkpoint.

    Raises:
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    hps = utils.get_hparams_from_file(config)

    # A model/optimizer pair is built only so the output carries an
    # "optimizer" entry. The optimizer is never stepped, so the entry holds no
    # training state. Presumably this keeps loaders that expect the key
    # working — confirm with the loading code.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # trusted checkpoint files.
    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))
    model_weights = _filter_model_weights(checkpoint["model"], ishalf)

    torch.save(
        {
            "model": model_weights,
            "iteration": 0,
            "optimizer": optim_g.state_dict(),
            "learning_rate": 0.0001,
        },
        output_model,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Create a release copy of a training checkpoint."
    )
    parser.add_argument("-c", "--config", type=str, default="configs/config.json")
    # The input checkpoint is mandatory: without it there is nothing to convert,
    # and the default output name is derived from it.
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Input checkpoint path"
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Output path (default: <input>_release[_half]<ext>)",
    )
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )

    args = parser.parse_args()

    output = args.output
    if output is None:
        # Derive "<input stem>_release[_half]<input extension>" next to the input.
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
    # Log message: "Model compression succeeded, output model: <absolute path>".
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")
